# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

sys.path.append("../")
import enum

from conv2d_common import (
    CommonConvFunction,
    CommonCutlassConvKernelDeclare,
    CommonCutlassConvKernelExecute,
    CommonTail,
    GenerateFunctionForPhi,
)
from util import SubstituteTemplate, TileDesc

# Preamble of the generated .cu file: the CUTLASS / Paddle headers the kernels
# below depend on, followed by the opening of the nested
# phi::fusion::cutlass_internal namespaces. The matching closing braces are
# presumably emitted by CommonTail, which __main__ appends last — confirm.

cba_header = '''
// Generated by conv2d_bias_act.py - Do not edit.

#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"


namespace phi {
namespace fusion {
namespace cutlass_internal {
'''

# Template for one CUTLASS conv2d forward (fprop) kernel with a fused
# bias + activation epilogue. generate_sm75_1688 instantiates it once per
# tile configuration and activation, so the generated file contains many
# kernels of this shape.

# Fills the conv-kind and epilogue placeholders of the shared declaration
# template; the ${...} placeholders inside "epi_part" and in the rest of the
# template are filled per kernel from kernel_dict.
dict_for_declare_part = {
    "conv_kind_name": "DefaultConv2dFprop",
    "epi_part": "${epi_func}< ${element_c}, ${epilogue_vector_length}, ${element_accum}, ${element_epilogue}>",
}

# Full template = shared declaration + argument setup + shared execution code.
# The strides below follow the NHWC layouts requested in kernel_dict. The bias
# is the third tensor ref (presumably the C source operand) and has all-zero
# strides, so one value per output channel is broadcast over N/H/W
# (assumes bias holds oc elements — confirm with the caller).
# {1.f, 1.f} is presumably the epilogue (alpha, beta), i.e. conv + bias
# without scaling — confirm against the CUTLASS epilogue Params.
cba_kernel_no_alpha = (
    SubstituteTemplate(CommonCutlassConvKernelDeclare, dict_for_declare_part)
    + '''
  typename ImplicitGemm::Arguments arguments{
      problem_size,
      {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
      {(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}},
      {(cutlass::half_t *)(bias), {0, 0, 0}},
      {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
      {1.f, 1.f}};
'''
    + CommonCutlassConvKernelExecute
)

# Kernel template for leaky_relu, which needs an extra runtime parameter
# (fuse_alpha). It is derived from cba_kernel_no_alpha by textual
# substitution: a local `alpha` is read from params.alpha and appended to the
# {1.f, 1.f} argument list. If either anchor text disappeared from the base
# template, str.replace would silently do nothing and the LeakyRelu kernels
# would ignore alpha, so fail loudly at generation time instead.
_missing_alpha_anchors = [
    anchor
    for anchor in ("{1.f, 1.f}", "typename ImplicitG")
    if anchor not in cba_kernel_no_alpha
]
if _missing_alpha_anchors:
    raise RuntimeError(
        "cba_kernel_no_alpha no longer contains %r; update the "
        "cba_kernel_alpha substitution" % (_missing_alpha_anchors,)
    )
cba_kernel_alpha = cba_kernel_no_alpha.replace(
    "{1.f, 1.f}", "{1.f, 1.f, alpha}"
).replace(
    "typename ImplicitG", "float alpha = params.alpha; typename ImplicitG"
)


class CbaAct(enum.Enum):
    """Activations that can follow the conv2d + bias computation.

    Members are numbered 1..5 in declaration order (enum.auto starts at 1),
    so the declaration order below must not change.
    """

    Identity = enum.auto()
    Relu = enum.auto()
    Silu = enum.auto()
    LeakyRelu = enum.auto()
    Sigmoid = enum.auto()


# Activations for which kernels are generated; only these are supported for now.
SupportedAct = [
    CbaAct.Identity,
    CbaAct.Relu,
    CbaAct.Silu,
    CbaAct.LeakyRelu,
    CbaAct.Sigmoid,
]

# The mappings below are keyed by enum member rather than by position in
# SupportedAct, so reordering or trimming SupportedAct cannot silently pair an
# activation with another activation's epilogue functor or name.

# CUTLASS epilogue functor used for each activation.
ActTag = {
    CbaAct.Identity: 'cutlass::epilogue::thread::LinearCombination',
    CbaAct.Relu: 'cutlass::epilogue::thread::LinearCombinationRelu',
    CbaAct.Silu: 'cutlass::epilogue::thread::LinearCombinationSilu',
    CbaAct.LeakyRelu: 'cutlass::epilogue::thread::LinearCombinationLeakyRelu',
    CbaAct.Sigmoid: 'cutlass::epilogue::thread::LinearCombinationSigmoid',
}

# snake_case op name; generate_sm75_1688 derives the C++ function names from it
# and, upper-cased, the enum op name.
UnderScoreName = {
    CbaAct.Identity: "conv2d_bias",
    CbaAct.Relu: "conv2d_bias_relu",
    CbaAct.Silu: "conv2d_bias_silu",
    CbaAct.LeakyRelu: "conv2d_bias_leaky_relu",
    CbaAct.Sigmoid: "conv2d_bias_sigmoid",
}

# CamelCase op name, passed to GenerateFunctionForPhi.
CamelName = {
    CbaAct.Identity: "Conv2dBias",
    CbaAct.Relu: "Conv2dBiasRelu",
    CbaAct.Silu: "Conv2dBiasSilu",
    CbaAct.LeakyRelu: "Conv2dBiasLeakyRelu",
    CbaAct.Sigmoid: "Conv2dBiasSigmoid",
}

# Generate sm75 (Turing) Tensor Core conv code.
# CUTLASS Tensor Core operations are implemented with CUDA's mma instruction;
# the instruction shape used here is mma.m16n8k8.


def generate_sm75_1688():
    """Generate the sm75 conv2d + bias + activation kernels and dispatchers.

    For every activation in ``SupportedAct`` this emits one CUTLASS kernel
    instantiation per (iterator algorithm, alignment, math instruction, tile)
    combination, named ``<op>_sm75<N>`` with N counting from 0 per activation.
    It then emits the op's function from ``CommonConvFunction``, which receives
    the list of generated kernel names.

    Returns:
        str: The generated C++ source for all activations, concatenated.
    """
    kernel_dict = {
        "element_a": "cutlass::half_t",
        "layout_a": "cutlass::layout::TensorNHWC",
        "element_b": "cutlass::half_t",
        "layout_b": "cutlass::layout::TensorNHWC",
        "element_c": "cutlass::half_t",
        "layout_c": "cutlass::layout::TensorNHWC",
        "opcode_class": "cutlass::arch::OpClassTensorOp",
        "arch": "cutlass::arch::Sm75",
        # Fixed here; the stage count passed to TileDesc below is not read.
        "stages": "2",
        "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
        # alpha is always float!
        "element_epilogue": "float",
        "math_operator": "cutlass::arch::OpMultiplyAdd",
    }

    kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided"

    # Parameter space iterated over below.
    iterator_algorithms = [
        "cutlass::conv::IteratorAlgorithm::kOptimized",
        # "cutlass::conv::IteratorAlgorithm::kAnalytic",
    ]

    # (instruction shape, element A, element B, accumulator element)
    math_instructions = [
        # (
        #     "16,8,8",
        #     "cutlass::half_t",
        #     "cutlass::half_t",
        #     "cutlass::half_t",
        # ),
        (
            "16,8,8",
            "cutlass::half_t",
            "cutlass::half_t",
            "float",
        ),
    ]

    # Operand alignment (in elements) for A and B; applied per iteration below.
    alignments = [8]

    # NOTE(review): per the original note, oc should be divisible by this
    # epilogue vector length — confirm the caller enforces it.
    kernel_dict["epilogue_vector_length"] = "8"
    kernel_dict["split_k_slices"] = "1"

    code_parts = []
    for epi_func in SupportedAct:
        op_dict = {
            "func_name": UnderScoreName[epi_func].lower() + "_sm75",
            "enum_op_name": UnderScoreName[epi_func].upper(),
        }
        kernel_dict["epi_func"] = ActTag[epi_func]
        # LeakyRelu needs the runtime alpha parameter; the others use the
        # plain template. This choice is the same for every kernel of the op.
        cba_kernel = (
            cba_kernel_alpha
            if epi_func == CbaAct.LeakyRelu
            else cba_kernel_no_alpha
        )
        # Every kernel of this op is recorded so the generated C++ function
        # can list them (as a std::vector in the C++ code).
        kernel_names = []
        suffix = 0
        for iterator_algorithm in iterator_algorithms:
            kernel_dict["iterator_algorithm"] = iterator_algorithm
            for alignment in alignments:
                kernel_dict["align_a"] = str(alignment)
                kernel_dict["align_b"] = str(alignment)
                for math_inst in math_instructions:
                    tiles = [
                        TileDesc("64, 64, 64", 2, "32, 32, 64", math_inst),
                        TileDesc("64, 32, 64", 2, "32, 32, 64", math_inst),
                        TileDesc("128, 32, 64", 2, "32, 32, 64", math_inst),
                        TileDesc("128, 64, 64", 2, "32, 32, 64", math_inst),
                        TileDesc("64, 64, 32", 2, "32, 32, 32", math_inst),
                        TileDesc("64, 128, 32", 2, "32, 64, 32", math_inst),
                        TileDesc("64, 128, 64", 2, "64, 64, 32", math_inst),
                        TileDesc("64, 256, 32", 2, "64, 64, 32", math_inst),
                        TileDesc("128, 64, 32", 2, "64, 32, 32", math_inst),
                    ]
                    for tile in tiles:
                        kernel_dict["Tshape"] = tile.Tshape
                        kernel_dict["Wshape"] = tile.Wshape
                        kernel_dict["Ishape"] = tile.math_inst[0]
                        kernel_dict["element_accum"] = tile.math_inst[3]
                        kernel_dict["kernel_func_name"] = op_dict[
                            "func_name"
                        ] + str(suffix)
                        suffix += 1
                        code_parts.append(
                            SubstituteTemplate(cba_kernel, kernel_dict)
                        )
                        kernel_names.append(kernel_dict["kernel_func_name"])

        # Generate op code
        op_dict["all_kernel_func_name"] = "".join(
            name + ", \n" for name in kernel_names
        )
        code_parts.append(SubstituteTemplate(CommonConvFunction, op_dict))
    return "".join(code_parts)


if __name__ == "__main__":
    import os

    sm_versions = ["75"]
    all_code = cba_header
    all_code += generate_sm75_1688()
    all_code += GenerateFunctionForPhi(
        sm_versions, SupportedAct, UnderScoreName, CamelName
    )
    all_code += CommonTail
    # The output path is relative to the current working directory; create
    # the directory if needed so a fresh checkout does not fail on open().
    os.makedirs("generated", exist_ok=True)
    # The `with` block closes the file; no explicit close() is needed.
    with open("generated/conv2d_bias_act.cu", "w") as f:
        f.write(all_code)
